import torch

from tools.utils import size_action_space, dim_action_space, dtype_action_space


class Batch:
    """Preallocated tensor storage for one training mini-batch.

    Every buffer has ``args.batch_size`` as its leading dimension. Some
    buffers only exist when the corresponding ``args`` flag requests them, so
    code using them must check the same flags.

    Args:
        args: configuration object. This constructor reads ``device``,
            ``batch_size``, ``num_latents``, ``relabeling``, ``relabeling2``,
            ``ratio_for_predictor`` and ``state``. It is also kept as
            ``self.args``.
        obs_shape: per-sample observation shape. It is unpacked after the
            batch dimension of every observation-like buffer.
        action_space: passed to ``dtype_action_space`` (dtype of ``actions``)
            and ``dim_action_space`` (width of ``actions``).
    """

    def __init__(self, args, obs_shape, action_space):
        self.args = args
        dev = args.device
        act_dtype = dtype_action_space(action_space)
        n = args.batch_size
        obs_dims = (n, *obs_shape)

        # Goal buffers: latent goal vectors, goal observations (zero-filled)
        # and a long-typed step entry per sample.
        self.goals = torch.empty(
            (n, args.num_latents), device=dev, dtype=torch.float
        )
        self.goals_obs = torch.zeros(obs_dims, device=dev)
        self.goals_step = torch.empty((n, 1), device=dev, dtype=torch.long)

        # Optional buffers, allocated only when the config flags ask for them.
        wants_label_obs = args.relabeling in (1, 2, 4) or args.relabeling2
        if wants_label_obs:
            self.label_obs = torch.empty(obs_dims, device=dev)
        if not args.ratio_for_predictor:
            # Boolean per-sample flags, initialised to False.
            self.sac_train = torch.zeros((n,), dtype=torch.bool, device=dev)
        if args.state:
            # Two-component per-sample state, current and previous.
            self.states = torch.empty((n, 2), device=dev)
            self.prev_states = torch.empty((n, 2), device=dev)

        # Long-typed index buffers. No device= is passed, so they land on
        # torch's default device rather than args.device.
        # NOTE(review): confirm the device difference is intentional.
        self.index = torch.empty(n, dtype=torch.long)
        self.ind = torch.empty(n, dtype=torch.long)

        # Core transition buffers. obs/next_obs/actions are uninitialised;
        # rewards, masks and irewards start at zero.
        self.obs = torch.empty(obs_dims, device=dev)
        self.actions = torch.empty(
            (n, dim_action_space(action_space)), device=dev, dtype=act_dtype
        )
        self.rewards = torch.zeros((n, 1), device=dev)
        self.next_obs = torch.empty(obs_dims, device=dev)
        self.masks = torch.zeros((n, 1), device=dev)
        self.irewards = torch.zeros((n, 1), device=dev)

